# Copyright 2021 Huawei Technologies Co., Ltd.All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loading dataset."""
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.transforms.c_transforms as deC


def load_dataset(dataset_path, batch_size=32, *, shuffle=True,
                 num_parallel_workers=8, drop_remainder=True):
    """Create a batched dataset whose columns are cast to int32.

    Reads the ``input_ids``, ``attention_mask``, ``token_type_ids`` and
    ``label`` columns with ``de.MindDataset``, casts each one to
    ``mstype.int32``, then batches the result.

    Args:
        dataset_path: Source passed unchanged to ``de.MindDataset``
            (presumably a MindRecord file path or list of paths).
        batch_size: Number of rows per batch. Default: 32.
        shuffle: Forwarded to ``de.MindDataset``. Default: True.
        num_parallel_workers: Forwarded to ``de.MindDataset``. Default: 8.
        drop_remainder: Forwarded to ``batch``; when True the final
            incomplete batch is discarded. Default: True.

    Returns:
        The batched dataset object produced by ``ds.batch``.
    """
    # Single source of truth: the same columns are read and then cast,
    # so listing them once keeps the two steps from drifting apart.
    columns = ["input_ids", "attention_mask", "token_type_ids", "label"]
    ds = de.MindDataset(dataset_path,
                        columns_list=columns,
                        shuffle=shuffle,
                        num_parallel_workers=num_parallel_workers)
    type_cast_op = deC.TypeCast(mstype.int32)
    # One map per column: TypeCast converts a single column at a time.
    for column in columns:
        ds = ds.map(operations=type_cast_op, input_columns=column)
    ds = ds.batch(batch_size, drop_remainder=drop_remainder)
    return ds
